{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kSNEB-3orJ9C"
   },
   "source": [
    "# **Token Classification: Sparse Transfer Learning with the Python API**\n",
    "\n",
    "In this example, you will fine-tune a 90% pruned BERT model onto the WNUT NER dataset using SparseML's Hugging Face Integration.\n",
    "\n",
    "### **Sparse Transfer Learning Overview**\n",
    "\n",
    "Sparse Transfer Learning is very similiar to typical fine-tuning you are used to when training models. However, with Sparse Transfer Learning, we start the training process from a pre-sparsified checkpoint and maintain the sparsity structure while the fine tuning occurs.\n",
    "\n",
    "At the end, you will have a sparse model trained on your dataset, ready to be deployed with DeepSparse for GPU-class performance on CPUs!\n",
    "\n",
    "### **Pre-Sparsified BERT**\n",
    "SparseZoo, Neural Magic's open source repository of pre-sparsified models, contains a 90% pruned version of BERT, which has been sparsified on the upstream Wikipedia and BookCorpus datasets with the\n",
    "masked language modeling objective. [Check out the model card](https://sparsezoo.neuralmagic.com/models/nlp%2Fmasked_language_modeling%2Fobert-base%2Fpytorch%2Fhuggingface%2Fwikipedia_bookcorpus%2Fpruned90-none). We will use this model as the starting point for the transfer learning process.\n",
    "\n",
    "\n",
    "**Let's dive in!**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y0WybTbssU0g"
   },
   "source": [
    "## **Installation**\n",
    "\n",
    "Install SparseML via `pip`.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "Fzor2W16DJ2A"
   },
   "outputs": [],
   "source": [
    "!pip install sparseml[transformers]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_jY0SKdXFGO3"
   },
   "source": [
    "If you are running on Google Colab, restart the runtime after this step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "XXj0S5Jdq2M-"
   },
   "outputs": [],
   "source": [
    "import sparseml\n",
    "from sparsezoo import Model\n",
    "from sparseml.transformers.utils import SparseAutoModel\n",
    "from sparseml.transformers.sparsification import Trainer, TrainingArguments\n",
    "import numpy as np\n",
    "from transformers import (\n",
    "    AutoModelForTokenClassification,\n",
    "    AutoConfig, \n",
    "    AutoTokenizer,\n",
    "    EvalPrediction,\n",
    "    DataCollatorForTokenClassification,\n",
    "    PreTrainedTokenizerFast\n",
    ")\n",
    "from datasets import load_dataset, load_metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A6GwDnLL2Zn_"
   },
   "source": [
    "## **Step 1: Load a Dataset**\n",
    "\n",
    "SparseML is integrated with Hugging Face, so we can use the `datasets` class to load datasets from the Hugging Face hub or from local files.\n",
    "\n",
    "[WNUT Dataset Card](https://huggingface.co/datasets/wnut_17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "znkiQtum2GBj"
   },
   "outputs": [],
   "source": [
    "# load dataset from HF hub\n",
    "dataset = load_dataset(\"wnut_17\")\n",
    "dataset[\"train\"].to_json(\"wnut_17-train.json\")\n",
    "dataset[\"validation\"].to_json(\"wnut_17-validation.json\")\n",
    "\n",
    "# alternatively, load from JSONL file\n",
    "data_files = {}\n",
    "data_files[\"train\"] = \"wnut_17-train.json\"\n",
    "data_files[\"validation\"] = \"wnut_17-validation.json\"\n",
    "dataset_from_json = load_dataset('json', data_files=data_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PeCdZ08a5BhT"
   },
   "source": [
    "We can see the input is a `tokens` which is a list of words and the labels are `ner_tags` which are integers corresponding to a tag type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uhK9inmZ47L9"
   },
   "outputs": [],
   "source": [
    "!head wnut_17-train.json --lines=10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2ju94P_l8-q5"
   },
   "source": [
    "## **Step 2: Setup Evaluation Metric**\n",
    "\n",
    "WNUT is a NER task. We will use the [seqeval](https://huggingface.co/spaces/evaluate-metric/seqeval) metric to evaluate the accuracy of the pipeline. \n",
    "\n",
    "The seqeval metric needs to be passed tags rather than tag indexes, so we need to create a mapping between the indexes and the tags so that we can pass the tags to the seqeval metric.\n",
    "\n",
    "Per the [WNUT dataset card](https://huggingface.co/datasets/wnut_17), the NER tags map to the following classes:\n",
    "\n",
    "```\n",
    "{\n",
    "  0: \"O\", \n",
    "  1: \"B-corporation\", \n",
    "  2: \"I-corporation\", \n",
    "  3: \"B-creative-work\", \n",
    "  4: \"I-creative-work\", \n",
    "  5: \"B-group\", \n",
    "  6: \"I-group\", \n",
    "  7: \"B-location\", \n",
    "  8: \"I-location\", \n",
    "  9: \"B-person\", \n",
    "  10: \"I-person\", \n",
    "  11: \"B-product\", \n",
    "  12: \"I-product\"\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CkvbT1i9p87z"
   },
   "outputs": [],
   "source": [
    "# label mapping\n",
    "LABEL_MAP = {\n",
    "  0: \"O\", \n",
    "  1: \"B-corporation\", \n",
    "  2: \"I-corporation\", \n",
    "  3: \"B-creative-work\", \n",
    "  4: \"I-creative-work\", \n",
    "  5: \"B-group\", \n",
    "  6: \"I-group\", \n",
    "  7: \"B-location\", \n",
    "  8: \"I-location\", \n",
    "  9: \"B-person\", \n",
    "  10: \"I-person\", \n",
    "  11: \"B-product\", \n",
    "  12: \"I-product\"\n",
    "}\n",
    "\n",
    "# other configs\n",
    "INPUT_COL = \"tokens\"\n",
    "LABEL_COL = \"ner_tags\"\n",
    "SPECIAL_TOKEN_ID = -100\n",
    "NUM_LABELS = len(LABEL_MAP)\n",
    "\n",
    "\n",
    "print(dataset_from_json)\n",
    "print(dataset_from_json[\"train\"][0][INPUT_COL])\n",
    "print(dataset_from_json[\"train\"][0][LABEL_COL])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZZUOmaW1u7C1"
   },
   "outputs": [],
   "source": [
    "# load evaluation metric - seqeval\n",
    "metric = load_metric(\"seqeval\")\n",
    "\n",
    "# setup metrics function\n",
    "def compute_metrics(p: EvalPrediction):\n",
    "  predictions, labels = p\n",
    "  predictions = np.argmax(predictions, axis=2)\n",
    "\n",
    "  # Remove ignored index (special tokens) and convert indexed tags to labels\n",
    "  true_predictions = [\n",
    "    [LABEL_MAP[pred] for (pred, lab) in zip(prediction, label) if lab != SPECIAL_TOKEN_ID]\n",
    "    for prediction, label in zip(predictions, labels)\n",
    "  ]\n",
    "  true_labels = [\n",
    "    [LABEL_MAP[lab] for (_, lab) in zip(prediction, label) if lab != SPECIAL_TOKEN_ID]\n",
    "    for prediction, label in zip(predictions, labels)\n",
    "  ]\n",
    "\n",
    "  # example: results = metrics.compute(predictions=[\"0\", \"B-group\", \"0\"], true_labels=[\"0\", \"B-person\", \"B-group\"])\n",
    "  #   we used the LABEL_MAP to convert the tags (which are integers in wnut) into the corresponding LABEL\n",
    "  results = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "\n",
    "  return {\n",
    "    \"precision\": results[\"overall_precision\"],\n",
    "    \"recall\": results[\"overall_recall\"],\n",
    "    \"f1\": results[\"overall_f1\"],\n",
    "    \"accuracy\": results[\"overall_accuracy\"],\n",
    "  }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zERYC_SL4J2A"
   },
   "source": [
    "## **Step 3: Train the Teacher**\n",
    "\n",
    "To support the sparse transfer learning process, we will first train a dense teacher model from scratch - which we can then distill during the sparse transfer learning process.\n",
    "\n",
    "In this case, we will use BERT base."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "235nhfjuAade"
   },
   "outputs": [],
   "source": [
    "# load teacher model and tokenizer\n",
    "TEACHER = \"bert-base-uncased\"\n",
    "teacher_config = AutoConfig.from_pretrained(TEACHER, num_labels=NUM_LABELS)\n",
    "tokenizer = AutoTokenizer.from_pretrained(TEACHER)\n",
    "teacher = AutoModelForTokenClassification.from_pretrained(TEACHER, config=teacher_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sejy51ZBIXSn"
   },
   "source": [
    "### **Tokenize Dataset**\n",
    "\n",
    "Run the tokenizer on the dataset. \n",
    "\n",
    "In this function, we handle the case where an individual word is tokenized into multiple tokens. In particular, we set the `label_id = SPECIAL_TOKEN_ID` for each token besides the first token in a word. \n",
    "\n",
    "When evaluating the accuracy with `compute_metrics` (defined above), we filter out tokens with `SPECIAL_TOKEN_ID`, such that each word counts only once in the precision and recall calculations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "v-kYPGogA4XN"
   },
   "outputs": [],
   "source": [
    "def preprocess_fn(examples):\n",
    "  tokenized_inputs = tokenizer(\n",
    "    examples[INPUT_COL], \n",
    "    padding=\"max_length\", \n",
    "    max_length=min(tokenizer.model_max_length, 128), \n",
    "    truncation=True,\n",
    "    is_split_into_words=True # the texts in our dataset are lists of words (with a label for each word)\n",
    "  )\n",
    "  \n",
    "  labels = []\n",
    "  for i, label in enumerate(examples[LABEL_COL]):\n",
    "    word_ids = tokenized_inputs.word_ids(batch_index=i)\n",
    "    previous_word_idx = None\n",
    "    label_ids = []\n",
    "\n",
    "    for word_idx in word_ids:\n",
    "      # Special tokens have a word id that is None. We set the label to SPECIAL_TOKEN_ID\n",
    "      # so they are automatically ignored in the loss function.\n",
    "      if word_idx is None:\n",
    "        label_ids.append(SPECIAL_TOKEN_ID)\n",
    "\n",
    "      # We set the label for the first token of each word.\n",
    "      elif word_idx != previous_word_idx:\n",
    "        label_ids.append(label[word_idx])\n",
    "\n",
    "      # We will not label the other tokens of a work, so set to SPECIAL_TOKEN_ID\n",
    "      else:\n",
    "        label_ids.append(SPECIAL_TOKEN_ID)\n",
    "      previous_word_idx = word_idx\n",
    "    labels.append(label_ids)\n",
    "\n",
    "  tokenized_inputs[\"labels\"] = labels\n",
    "  return tokenized_inputs\n",
    "\n",
    "# tokenize the dataset\n",
    "tokenized_dataset = dataset_from_json.map(\n",
    "    preprocess_fn,\n",
    "    batched=True,\n",
    "    desc=\"Running tokenizer on dataset\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pTb5lf-kJ_st"
   },
   "source": [
    "### **Teacher Training: Fine-Tune the Teacher**\n",
    "\n",
    "We use the native Hugging Face `Trainer` (which we import as `HFTrainer`) to train the model. Check out the [Hugging Face documentation](https://huggingface.co/docs/transformers/main_classes/trainer) for more details on the `Trainer` as needed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G9BFs-NUBMhA"
   },
   "outputs": [],
   "source": [
    "from transformers import Trainer as HFTrainer\n",
    "from transformers import TrainingArguments as HFTrainingArguments\n",
    "\n",
    "# setup trainer arguments\n",
    "teacher_training_args = HFTrainingArguments(\n",
    "    output_dir=\"./teacher_training\",\n",
    "    do_train=True,\n",
    "    do_eval=True,\n",
    "    num_train_epochs=20.0,\n",
    "    learning_rate=5e-5,\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    save_total_limit=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32)\n",
    "\n",
    "# initialize trainer\n",
    "teacher_trainer = HFTrainer(\n",
    "    model=teacher,\n",
    "    args=teacher_training_args,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"validation\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if teacher_training_args.fp16 else None),\n",
    "    compute_metrics=compute_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gej5zPQMBm2Q"
   },
   "outputs": [],
   "source": [
    "# run training\n",
    "%rm -rf ./teacher_training\n",
    "teacher_trainer.train(resume_from_checkpoint=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5Alkqwl0I4kR"
   },
   "source": [
    "## **Step 4: Sparse Transfer Learning**\n",
    "\n",
    "Now that we have the teacher trained, we can sparse transfer learn from the pre-sparsified version of BERT with distillation support. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1GEhYi53HoAH"
   },
   "source": [
    "First, we need to select a sparse checkpoint to begin the training process. In this case, we are fine-tuning a 90% pruned version of BERT onto the TweetEval Emotion dataset. This model is available in SparseZoo, identified by the following stub:\n",
    "```\n",
    "zoo:nlp/masked_language_modeling/obert-base/pytorch/huggingface/wikipedia_bookcorpus/pruned90-none\n",
    "```\n",
    "\n",
    "Next, we need to create a sparsification recipe for usage in the training process. Recipes are YAML files that encode the sparsity related algorithms and parameters to be applied by SparseML. For Sparse Transfer Learning, we need to use a recipe that instructs SparseML to maintain sparsity during the training process and to apply quantization over the final few epochs. \n",
    "\n",
    "In SparseZoo, there is a transfer recipe which was used to fine-tune BERT onto the CONLL2003 task (which is also a NER task). Since the WNUT dataset is a similiar problem to CONLL, we will use the CONLL recipe, which is identified by the following stub:\n",
    "\n",
    "```\n",
    "zoo:nlp/token_classification/obert-base/pytorch/huggingface/conll2003/pruned90_quant-none\n",
    "```\n",
    "\n",
    "Use the `sparsezoo` python client to download the models and recipe using their SparseZoo stubs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ykg8fEN2Q5o_"
   },
   "outputs": [],
   "source": [
    "# 90% pruned upstream BERT trained on MLM objective (pruned90)\n",
    "model_stub = \"zoo:nlp/masked_language_modeling/obert-base/pytorch/huggingface/wikipedia_bookcorpus/pruned90-none\" \n",
    "model_path = Model(model_stub, download_path=\"./model\").training.path\n",
    "\n",
    "# sparse transfer learning recipe for conll2003 (pruned90_quant)\n",
    "transfer_stub = \"zoo:nlp/token_classification/obert-base/pytorch/huggingface/conll2003/pruned90_quant-none\"\n",
    "recipe_path = Model(transfer_stub, download_path=\"./transfer_recipe\").recipes.default.path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RLe8iEWxV_zz"
   },
   "source": [
    "We can see that the upstream model (trained on Wikipedia BookCorpus) and  configuration files have been downloaded to the local directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0NTVj1kPRSCW"
   },
   "outputs": [],
   "source": [
    "%ls ./model/training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "orjvrvdCWEUi"
   },
   "source": [
    "We can see that a transfer learning recipe has been downloaded. The `ConstantPruningModifier` instructs SparseML to maintain the sparsity structure of the network as the model trains and the `QuantizationModifier` instructs SparseML to run Quantization Aware Training at the end of training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eUYg-7eBRT5f"
   },
   "outputs": [],
   "source": [
    "%cat ./transfer_recipe/recipe/recipe_original.md"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inspecting the Recipe\n",
    "\n",
    "Here is the transfer learning recipe:\n",
    "\n",
    "```yaml\n",
    "version: 1.1.0\n",
    "\n",
    "# General Variables\n",
    "num_epochs: 13\n",
    "init_lr: 1.5e-4 \n",
    "final_lr: 0\n",
    "\n",
    "qat_start_epoch: 8.0\n",
    "observer_epoch: 12.0\n",
    "quantize_embeddings: 1\n",
    "\n",
    "distill_hardness: 1.0\n",
    "distill_temperature: 2.0\n",
    "\n",
    "# Modifiers:\n",
    "\n",
    "training_modifiers:\n",
    "  - !EpochRangeModifier\n",
    "      end_epoch: eval(num_epochs)\n",
    "      start_epoch: 0.0\n",
    "\n",
    "  - !LearningRateFunctionModifier\n",
    "      start_epoch: 0\n",
    "      end_epoch: eval(num_epochs)\n",
    "      lr_func: linear\n",
    "      init_lr: eval(init_lr)\n",
    "      final_lr: eval(final_lr)\n",
    "    \n",
    "quantization_modifiers:\n",
    "  - !QuantizationModifier\n",
    "      start_epoch: eval(qat_start_epoch)\n",
    "      disable_quantization_observer_epoch: eval(observer_epoch)\n",
    "      freeze_bn_stats_epoch: eval(observer_epoch)\n",
    "      quantize_embeddings: eval(quantize_embeddings)\n",
    "      quantize_linear_activations: 0\n",
    "      exclude_module_types: ['LayerNorm']\n",
    "      submodules:\n",
    "        - bert.embeddings\n",
    "        - bert.encoder\n",
    "        - classifier\n",
    "\n",
    "distillation_modifiers:\n",
    "  - !DistillationModifier\n",
    "     hardness: eval(distill_hardness)\n",
    "     temperature: eval(distill_temperature)\n",
    "     distill_output_keys: [logits]\n",
    "\n",
    "constant_modifiers:\n",
    "  - !ConstantPruningModifier\n",
    "      start_epoch: 0.0\n",
    "      params: __ALL_PRUNABLE__\n",
    "```\n",
    "\n",
    "\n",
    "The `Modifiers` in the transfer learning recipe are the important items that encode how SparseML should modify the training process for Sparse Transfer Learning:\n",
    "- `ConstantPruningModifier` tells SparseML to pin weights at 0 over all epochs, maintaining the sparsity structure of the network\n",
    "- `QuantizationModifier` tells SparseML to quanitze the weights with quantization aware training over the last 5 epochs\n",
    "- `DistillationModifier` tells SparseML how to apply distillation during the trainign process, targeting the logits\n",
    "\n",
    "Below, SparseML's `Trainer` will parses the modifiers and updates the training process to implement the algorithms specified here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CCTxNPD4XxHs"
   },
   "source": [
    "Next, we will set up the Hugging Face `tokenizer, config, and model`. Since the tokenizer for the teacher and student are the same, we can use the same `tokenizer` and `dataset` used to train the teacher abve.\n",
    "\n",
    "These are all native Hugging Face objects, so check out the Hugging Face docs for more details on `AutoModel`, `AutoConfig`, and `AutoTokenizer` as needed. \n",
    "\n",
    "We instantiate these classes by passing the local path to the directory containing the `pytorch_model.bin`, `tokenizer.json`, and `config.json` files from the SparseZoo download."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dhN1oGcTQ9RE"
   },
   "outputs": [],
   "source": [
    "# note: since the teacher and student have the same tokenizer, we can use the one from the teacher training\n",
    "\n",
    "# initialize config\n",
    "config = AutoConfig.from_pretrained(model_path, num_labels=NUM_LABELS)\n",
    "\n",
    "# initialize model\n",
    "model_kwargs = {\"config\": config}\n",
    "model_kwargs[\"state_dict\"], s_delayed = SparseAutoModel._loadable_state_dict(model_path)\n",
    "model = AutoModelForTokenClassification.from_pretrained(model_path, **model_kwargs,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g5jB7s1bM4Qh"
   },
   "source": [
    "### **Sparse Transfer Learning: Fine-Tune the Model**\n",
    "\n",
    "SparseML has a custom `Trainer` class that inherits from the [Hugging Face `Trainer` Class](https://huggingface.co/docs/transformers/main_classes/trainer). As such, the SparseML `Trainer` has all of the existing functionality of the HF trainer. However, in addition, we can supply a `recipe` and (optionally) a `teacher`. \n",
    "\n",
    "\n",
    "As we saw above, the `recipe` encodes the sparsity related algorithms and hyperparameters of the training process in a YAML file. The SparseML `Trainer` parses the `recipe` and adjusts the training workflow to apply the algorithms in the recipe. We use the `recipe_args` function to modify the recipe slightly (training for more epochs than used for Conll2003).\n",
    "\n",
    "\n",
    "The `teacher` is an optional argument that instructs SparseML to apply model distillation to support the training process. We pass the teacher rained abobve here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "34IXj1n6RCgQ"
   },
   "outputs": [],
   "source": [
    "# setup trainer arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./training_output\",\n",
    "    do_train=True,\n",
    "    do_eval=True,\n",
    "    resume_from_checkpoint=False,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_total_limit=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    fp16=False)\n",
    "\n",
    "# initialize trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    model_state_path=model_path,\n",
    "    recipe=recipe_path,\n",
    "    recipe_args={\n",
    "        \"num_epochs\":25,\n",
    "        \"init_lr\": 5e-5,\n",
    "        \"qat_start_epoch\": 20.0,\n",
    "        \"observer_epoch\": 24.0},\n",
    "    teacher=teacher,\n",
    "    metadata_args=[\"per_device_train_batch_size\",\"per_device_eval_batch_size\",\"fp16\"],\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"validation\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None),\n",
    "    compute_metrics=compute_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LLBgAdqyRDro"
   },
   "outputs": [],
   "source": [
    "# step 5: run training\n",
    "%rm -rf training_output\n",
    "train_result = trainer.train(resume_from_checkpoint=False)\n",
    "trainer.save_model()\n",
    "trainer.save_state()\n",
    "trainer.save_optimizer_and_scheduler(training_args.output_dir)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Step 7: Export To ONNX**\n",
    "\n",
    "Run the following to export the model to ONNX. The script creates a `deployment` folder containing ONNX file and the necessary configuration files (e.g. `tokenizer.json`) for deployment with DeepSparse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-rhWjiHBeR7M"
   },
   "outputs": [],
   "source": [
    "!sparseml.transformers.export_onnx \\\n",
    "  --model_path training_output \\\n",
    "  --task token_classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Optional: Deploy with DeepSparse**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-XubpXohO_8A"
   },
   "outputs": [],
   "source": [
    "%pip install deepsparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "m_USM8mCPETg"
   },
   "outputs": [],
   "source": [
    "from deepsparse import Pipeline\n",
    "\n",
    "pipeline = Pipeline.create(\"token_classification\", model_path=\"./deployment\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Bncg7Xx5ONqB"
   },
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "prediction = pipeline(\"Japan, co-hosts of the World Cup in 2002 and ranked 20th in the world by FIFA, are favourites to regain their title here.\")\n",
    "pprint(prediction.predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4zqbsVpoSZ-R"
   },
   "outputs": [],
   "source": [
    "prediction = pipeline(\"China controlled most of the match and saw several chances missed until the 78th minute when Uzbek striker Igor Shkvyrin took advantage of a misdirected defensive header to lob the ball over the advancing Chinese keeper and into an empty net.\")\n",
    "pprint(prediction.predictions)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyPeTTOJcgeFtheavt8anQmB",
   "provenance": [
    {
     "file_id": "1cXfeYQ_ZbnJRoQsaYOIDR2N7YP--mMiL",
     "timestamp": 1677358343826
    },
    {
     "file_id": "1Zawa0sifXr2wIl9tbF7ySJ7xYY0dtTzI",
     "timestamp": 1677345946788
    }
   ]
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
